# -*- coding: utf-8 -*-
'''
Created on 2017年4月20日

@author: ZhuJiahui506
'''

import os
import numpy as np
import tensorflow as tf
from brt_embedding.data_preparation import TransPreparaion
import time
from file_utils.file_reader import read_to_1d_list, read_to_2d_list, read_to_1d_list_gbk, read_to_2d_list_gbk


def trans_e(data_model, embedding_size=200, margin=1.0, batch_size=1000, epoch_num=10, negative_type='bern'):
    '''
    Train TransE embeddings (L1 distance score, margin-based ranking loss).

    :param data_model: data model (TransPreparaion instance); must expose
        entity_num, relation_num, train_id_triples, validate_id_triples and
        generate_batch(triples, batch_size, negative_type)
    :param embedding_size: dimensionality of entity and relation vectors
    :param margin: margin of the ranking objective
    :param batch_size: mini-batch size
    :param epoch_num: number of training epochs
    :param negative_type: negative sampling strategy (e.g. 'bern')
    :return
        entity_embeddings: all entity vectors (numpy 2d array)
        relation_embeddings: all relation vectors (numpy 2d array)
    '''

    alpha = 0.05  # weight of the soft norm regularizer

    # ---- graph definition ----
    graph = tf.Graph()
    with graph.as_default():
        # ids of positive (h, r, t) triples and their corrupted counterparts
        pos_hs = tf.placeholder(tf.int32, shape=[None])
        pos_rs = tf.placeholder(tf.int32, shape=[None])
        pos_ts = tf.placeholder(tf.int32, shape=[None])
        neg_hs = tf.placeholder(tf.int32, shape=[None])
        neg_rs = tf.placeholder(tf.int32, shape=[None])
        neg_ts = tf.placeholder(tf.int32, shape=[None])

        # model parameters: one embedding matrix for entities, one for relations
        embeddings = dict()
        with tf.variable_scope('transe' + 'embedding'):
            embeddings['entity'] = tf.Variable(tf.truncated_normal(
                    [data_model.entity_num, embedding_size], stddev=1.0 / np.sqrt(embedding_size)))
            embeddings['relation'] = tf.Variable(tf.truncated_normal(
                    [data_model.relation_num, embedding_size], stddev=1.0 / np.sqrt(embedding_size)))
            # NOTE(review): these reassignments replace the dict entries with
            # normalization ops, not Variables; gradients still reach the
            # underlying Variables, and eval() below returns the row-normalized
            # embeddings — presumably intentional, but worth confirming.
            embeddings['entity'] = tf.nn.l2_normalize(embeddings['entity'], 1)
            embeddings['relation'] = tf.nn.l2_normalize(embeddings['relation'], 1)

        # look up the vector of each triple component
        phs = tf.nn.embedding_lookup(embeddings['entity'], pos_hs)
        prs = tf.nn.embedding_lookup(embeddings['relation'], pos_rs)
        pts = tf.nn.embedding_lookup(embeddings['entity'], pos_ts)
        nhs = tf.nn.embedding_lookup(embeddings['entity'], neg_hs)
        nrs = tf.nn.embedding_lookup(embeddings['relation'], neg_rs)
        nts = tf.nn.embedding_lookup(embeddings['entity'], neg_ts)

        # L1 score ||h + r - t||_1 per triple; lower means more plausible
        pos_loss = tf.reduce_sum(tf.abs(phs + prs - pts), 1)
        neg_loss = tf.reduce_sum(tf.abs(nhs + nrs - nts), 1)
        # margin ranking loss: negative score must exceed positive by >= margin
        base_loss = tf.reduce_sum(tf.nn.relu(pos_loss + margin - neg_loss))

        # soft constraint penalizing any embedding whose squared norm exceeds 1
        norm_loss = tf.reduce_sum(tf.nn.relu(tf.reduce_sum(tf.pow(phs, 2), 1) - 1))
        for looked_up in (pts, nhs, nts, prs, nrs):
            norm_loss += tf.reduce_sum(tf.nn.relu(tf.reduce_sum(tf.pow(looked_up, 2), 1) - 1))

        loss = base_loss + alpha * norm_loss
        optimizer = tf.train.AdagradOptimizer(0.1).minimize(loss)

    def _make_feed(positive, negative):
        # Split (h, r, t) id triples into the six placeholder feeds.
        return {
            pos_hs: [x[0] for x in positive],
            pos_rs: [x[1] for x in positive],
            pos_ts: [x[2] for x in positive],
            neg_hs: [x[0] for x in negative],
            neg_rs: [x[1] for x in negative],
            neg_ts: [x[2] for x in negative],
        }

    # ---- training ----
    with tf.Session(graph=graph) as sess:
        tf.global_variables_initializer().run()  # initialize all variables
        for e in range(epoch_num):
            batch_num = len(data_model.train_id_triples) // batch_size  # remainder dropped
            total_loss = 0
            start = time.time()
            for _ in range(batch_num):
                batch_positive, batch_negative = data_model.generate_batch(
                    data_model.train_id_triples, batch_size, negative_type)
                (_, loss_val) = sess.run([optimizer, loss],
                                         feed_dict=_make_feed(batch_positive, batch_negative))
                total_loss += loss_val
            end = time.time()
            # report the epoch 1-based so the log reads "1/10 ... 10/10"
            print("{}/{}, train_loss = {:.3f}, this epoch time = {:.3f}".format(
                e + 1, epoch_num, total_loss, end - start))

            # validation: fraction of pairs where the true triple scores lower
            valid_positive, valid_negative = data_model.generate_batch(
                data_model.validate_id_triples, batch_size, negative_type)
            # fetch both score tensors in ONE run — the original issued two
            # sess.run calls with the same feed, recomputing the forward pass twice
            pos_scores, neg_scores = sess.run(
                [pos_loss, neg_loss], feed_dict=_make_feed(valid_positive, valid_negative))
            accuracy = np.mean([1 if p < n else 0 for p, n in zip(pos_scores, neg_scores)])
            print("valid accuracy %f" % accuracy)

        # materialize the (normalized) embedding matrices as numpy arrays
        entity_embeddings = embeddings['entity'].eval()
        relation_embeddings = embeddings['relation'].eval()

    return entity_embeddings, relation_embeddings


if __name__ == '__main__':

    # resolve dataset directories relative to the parent of the working directory
    now_directory = os.getcwd()
    root_directory = os.path.dirname(now_directory) + '/'
    read_directory = root_directory + 'dataset/trans_e_data'
    write_directory = root_directory + 'dataset/topic_trans_e_result'
    # makedirs(exist_ok=True) also creates missing parents and avoids the
    # check-then-create race of os.path.exists + os.mkdir
    os.makedirs(write_directory, exist_ok=True)

    # topic list is UTF-8; the remaining files are GBK-encoded
    entity_list = read_to_1d_list(read_directory + '/topic_list.txt')
    relation_list = read_to_1d_list_gbk(read_directory + '/relation_list.txt')
    train_triples = read_to_2d_list_gbk(read_directory + '/train_triples.txt', ' ')
    validate_triples = read_to_2d_list_gbk(read_directory + '/validate_triples.txt', ' ')
    test_triples = read_to_2d_list_gbk(read_directory + '/test_triples.txt', ' ')

    negative_type = 'bern'  # negative sampling strategy

    embedding_size = 200  # dimensionality of entity and relation vectors
    margin = 1.0  # margin of the ranking objective
    batch_size = 2000  # mini-batch size
    epoch_num = 10  # number of training epochs

    # prepare the training data (id mappings, negative-sampling statistics)
    data_model = TransPreparaion(entity_list, relation_list, train_triples, validate_triples, test_triples, negative_type)

    entity_embeddings, relation_embeddings = trans_e(data_model, embedding_size, margin, batch_size, epoch_num, negative_type)

    # persist the learned embeddings as plain-text matrices
    np.savetxt(write_directory + '/entity_embeddings.txt', entity_embeddings)
    np.savetxt(write_directory + '/relation_embeddings.txt', relation_embeddings)